
[LTX-2] Fix flash attention shard_map for sequence lengths not divisible by context mesh axis #363

Open
mbohlool wants to merge 1 commit into main from ltx2_pipeline_fix

Conversation

@mbohlool (Collaborator)

Description:

When the sequence length (e.g., audio tokens) is not evenly divisible by the context mesh axis size, shard_map in _tpu_flash_attention raises a ValueError because it cannot partition the array evenly across devices.

For example, LTX-2 with 121 frames at 24 fps produces 126 audio latent tokens. On an 8-device context axis, 126 is not divisible by 8, causing the failure.

The existing _pad_data_for_flash already pads sequences for flash block-size alignment inside shard_map, but the shard_map itself requires even partitioning before entry.

This fix pads query/key/value sequence dimensions to the nearest multiple of the context mesh axis size before shard_map, and trims the output back to the original length afterward. Segment-ID masking inside wrap_flash_attention ensures padded positions do not affect attention results.
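The pad-then-trim approach described above can be sketched as follows. This is a minimal illustration in NumPy (the actual code would use `jax.numpy` equivalents); the helper name `pad_to_multiple` and the tensor shapes are hypothetical, chosen to mirror the 126-token / 8-device example from the PR description.

```python
import numpy as np

def pad_to_multiple(x, axis, multiple):
    """Zero-pad `x` along `axis` so its length becomes a multiple of `multiple`.

    Returns the padded array and the original length, so the caller can
    trim the result back after attention.
    """
    length = x.shape[axis]
    remainder = length % multiple
    if remainder == 0:
        return x, length
    pad_widths = [(0, 0)] * x.ndim
    pad_widths[axis] = (0, multiple - remainder)
    return np.pad(x, pad_widths), length

# Example: 121 frames at 24 fps -> 126 audio latent tokens,
# sharded over a context mesh axis of 8 devices.
context_axis_size = 8
q = np.ones((1, 126, 64))  # (batch, seq, head_dim)

q_padded, orig_len = pad_to_multiple(q, axis=1, multiple=context_axis_size)
# q_padded now has seq length 128, evenly partitionable across 8 devices.

# ... run flash attention under shard_map on q_padded ...
out = q_padded  # stand-in for the attention output

# Trim back to the original sequence length afterward.
out = np.take(out, np.arange(orig_len), axis=1)
```

As the description notes, trimming alone is not enough for correctness: segment-ID masking inside the sharded attention kernel is what keeps the padded positions from contributing to the attention scores.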

@mbohlool mbohlool requested a review from entrpn as a code owner March 24, 2026 16:59
@mbohlool mbohlool requested a review from prishajain1 March 24, 2026 17:00
@mbohlool mbohlool changed the title Fix flash attention shard_map for sequence lengths not divisible by context mesh axis [LTX-2] Fix flash attention shard_map for sequence lengths not divisible by context mesh axis Mar 24, 2026
@mbohlool mbohlool force-pushed the ltx2_pipeline_fix branch from e22ec2c to b8296b2 Compare March 24, 2026 17:25
@mbohlool mbohlool force-pushed the ltx2_pipeline_fix branch from b8296b2 to 9697900 Compare March 25, 2026 18:37
@mbohlool (Collaborator, Author)

@entrpn PTAL

